{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Result diversification with Elasticsearch\n",
    "This notebook demonstrates:\n",
    "1. Load a fashion product dataset\n",
    "2. Index the products in Elasticsearch as multimodal (image + text) vectors\n",
    "3. Search for items with a broad search term\n",
    "4. Re-rank the results with the MMR algorithm to diversify them.\n",
    "\n",
    "Check out our [blog post](https://www.elastic.co/search-labs/blog/diversify-results-maximum-marginal-relevance) on this topic to learn more about diversifying search results with MMR."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Setup and Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
     ]
    }
   ],
   "source": [
    "%pip install -r requirements.txt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import requests\n",
    "import numpy as np\n",
    "import kagglehub\n",
    "from itertools import repeat\n",
    "from concurrent.futures import ThreadPoolExecutor\n",
    "from tqdm import tqdm\n",
    "import time\n",
    "from elasticsearch import Elasticsearch\n",
    "from IPython.display import HTML, display\n",
    "from typing import List, Dict, Tuple"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Configuration\n",
    "\n",
    "Create a configuration file `elastic_config.env` in the following format to authenticate with Jina AI and your Elasticsearch cluster:\n",
    "```\n",
    "ELASTIC_API_KEY=<ELASTIC_KEY>\n",
    "ELASTIC_HOST=<HOST_URL>\n",
    "JINA_API_KEY=<JINA_KEY>\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Configuration loaded successfully\n"
     ]
    }
   ],
   "source": [
    "def load_config(file_path=\"elastic_config.env\"):\n",
    "    \"\"\"Load configuration from environment file\"\"\"\n",
    "    config = {}\n",
    "    try:\n",
    "        with open(file_path, \"r\") as file:\n",
    "            for line in file:\n",
    "                if \"=\" in line:\n",
    "                    key, value = line.strip().split(\"=\", 1)\n",
    "                    config[key] = value\n",
    "    except FileNotFoundError:\n",
    "        print(f\"Configuration file not found: {file_path}\")\n",
    "    return config\n",
    "\n",
    "\n",
    "config = load_config()\n",
    "# Fall back to environment variables so secrets do not have to live in a file.\n",
    "elastic_host = config.get(\"ELASTIC_HOST\") or os.environ.get(\"ELASTIC_HOST\")\n",
    "elastic_api_key = config.get(\"ELASTIC_API_KEY\") or os.environ.get(\"ELASTIC_API_KEY\")\n",
    "jina_api_key = config.get(\"JINA_API_KEY\") or os.environ.get(\"JINA_API_KEY\")\n",
    "\n",
    "# Fail fast: every later cell needs these credentials.\n",
    "missing_settings = [\n",
    "    name\n",
    "    for name, value in (\n",
    "        (\"ELASTIC_HOST\", elastic_host),\n",
    "        (\"ELASTIC_API_KEY\", elastic_api_key),\n",
    "        (\"JINA_API_KEY\", jina_api_key),\n",
    "    )\n",
    "    if not value\n",
    "]\n",
    "if missing_settings:\n",
    "    raise ValueError(\n",
    "        f\"Missing configuration values: {', '.join(missing_settings)}. \"\n",
    "        \"Add them to elastic_config.env or set them as environment variables.\"\n",
    "    )\n",
    "\n",
    "print(\"Configuration loaded successfully\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Load Dataset and Extract ID & Image URLs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Path to dataset files: /Users/peter/.cache/kagglehub/datasets/paramaggarwal/fashion-product-images-dataset/versions/1\n",
      "Loaded 44446 total products\n",
      "\n",
      "Filtered to 2694 bottomwear products\n"
     ]
    }
   ],
   "source": [
    "# Fetch the Kaggle fashion product dataset via kagglehub.\n",
    "dataset_path = kagglehub.dataset_download(\"paramaggarwal/fashion-product-images-dataset\")\n",
    "print(\"Path to dataset files:\", dataset_path)\n",
    "\n",
    "# Per-product JSON metadata files are read from this sub-folder.\n",
    "styles_folder = os.path.join(dataset_path, \"fashion-dataset/styles\")\n",
    "\n",
    "\n",
    "def load_dataset(folder_path):\n",
    "    \"\"\"Load all JSON files from the dataset folder\"\"\"\n",
    "    products = []\n",
    "\n",
    "    for filename in os.listdir(folder_path):\n",
    "        if filename.endswith(\".json\"):\n",
    "            file_path = os.path.join(folder_path, filename)\n",
    "            try:\n",
    "                with open(file_path, \"r\") as f:\n",
    "                    data = json.load(f)\n",
    "                    if \"data\" in data:\n",
    "                        products.append(data[\"data\"])\n",
    "            except Exception as e:\n",
    "                print(f\"Error reading {filename}: {e}\")\n",
    "\n",
    "    return products\n",
    "\n",
    "\n",
    "products = load_dataset(styles_folder)\n",
    "print(f\"Loaded {len(products)} total products\")\n",
    "\n",
    "# Filter for bottomwear only to limit data for this demo.\n",
    "# `or {}` / `or \"\"` also cover keys that may be present with a null value.\n",
    "bottomwear_products = [\n",
    "    product\n",
    "    for product in products\n",
    "    if ((product.get(\"subCategory\") or {}).get(\"typeName\") or \"\").lower()\n",
    "    == \"bottomwear\"\n",
    "]\n",
    "\n",
    "print(f\"\\nFiltered to {len(bottomwear_products)} bottomwear products\")\n",
    "\n",
    "products = bottomwear_products"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Extracted 2693 products with valid IDs and image URLs\n",
      "\n",
      "Limited to 1000 items for demo\n",
      "\n",
      "Sample items (alphabetically sorted):\n",
      "  - Femella Women Off White Shorts (Shorts, Off White)\n",
      "  - Nike Women Strong Poly Black Capri (Capris, Black)\n",
      "  - Flying Machine Men Blue Jeans (Jeans, Blue)\n",
      "  - Urban Yoga Men Black Shorts (Shorts, Black)\n",
      "  - Doodle Girls Lace Bow LT.Pink Leggings (Leggings, Pink)\n"
     ]
    }
   ],
   "source": [
    "def extract_id_and_image_url(products):\n",
    "    \"\"\"Extract ID and image URL from products\"\"\"\n",
    "    image_data = []\n",
    "\n",
    "    for product in products:\n",
    "        product_id = product.get(\"id\")\n",
    "\n",
    "        style_images = product.get(\"styleImages\", {})\n",
    "        default_image = style_images.get(\"default\", {})\n",
    "\n",
    "        image_url = default_image.get(\"resolutions\", {}).get(\"360X480\", \"\")\n",
    "        if not image_url:\n",
    "            image_url = default_image.get(\"imageURL\", \"\")\n",
    "\n",
    "        if product_id and image_url:\n",
    "            image_data.append(\n",
    "                {\n",
    "                    \"id\": product_id,\n",
    "                    \"image_url\": image_url,\n",
    "                    \"product_name\": product.get(\"productDisplayName\", \"\"),\n",
    "                    \"brand\": product.get(\"brandName\", \"\"),\n",
    "                    \"color\": product.get(\"baseColour\", \"\"),\n",
    "                    \"article_type\": product.get(\"articleType\", {}).get(\"typeName\", \"\"),\n",
    "                }\n",
    "            )\n",
    "\n",
    "    return image_data\n",
    "\n",
    "\n",
    "# Only use the first DEMO_LIMIT products to keep the demo light\n",
    "DEMO_LIMIT = 1000\n",
    "\n",
    "image_data = extract_id_and_image_url(products)\n",
    "print(f\"Extracted {len(image_data)} products with valid IDs and image URLs\")\n",
    "\n",
    "demo_image_data = image_data[:DEMO_LIMIT]\n",
    "print(f\"\\nLimited to {len(demo_image_data)} items for demo\")\n",
    "# Items are in dataset-file order; they are not sorted by product name.\n",
    "print(\"\\nSample items:\")\n",
    "for item in demo_image_data[:5]:\n",
    "    print(f\"  - {item['product_name']} ({item['article_type']}, {item['color']})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Create Multimodal (Text + Image) Embeddings with the Jina API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Getting embeddings...\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Getting embeddings: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [05:25<00:00,  3.08it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Retrieved 1000 embeddings\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def get_single_image_embedding(item, jina_api_key):\n",
    "    \"\"\"Embed one product (metadata text + image) and fuse the two vectors.\n",
    "\n",
    "    A single Jina request carries two inputs, the product metadata as text\n",
    "    and the image URL. The two returned embeddings are averaged into one\n",
    "    unit-length vector with ``to_avg_vector``.\n",
    "\n",
    "    Args:\n",
    "        item: Product dict with ``id``, ``image_url``, ``product_name``,\n",
    "            ``brand``, ``color`` and ``article_type`` keys.\n",
    "        jina_api_key: Jina API key, sent as a bearer token.\n",
    "\n",
    "    Returns:\n",
    "        The product fields plus ``image_vector``, or ``None`` when the\n",
    "        request fails or the response lacks either of the two embeddings.\n",
    "    \"\"\"\n",
    "    url = \"https://api.jina.ai/v1/embeddings\"\n",
    "    headers = {\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {jina_api_key}\",\n",
    "    }\n",
    "\n",
    "    product_data = {\n",
    "        \"product_name\": item[\"product_name\"],\n",
    "        \"brand\": item[\"brand\"],\n",
    "        \"color\": item[\"color\"],\n",
    "        \"article_type\": item[\"article_type\"],\n",
    "    }\n",
    "\n",
    "    data = {\n",
    "        \"model\": \"jina-embeddings-v4\",\n",
    "        \"dimensions\": 1024,\n",
    "        \"normalized\": True,\n",
    "        \"task\": \"retrieval.passage\",\n",
    "        \"embedding_type\": \"float\",\n",
    "        \"input\": [{\"text\": f\"{product_data}\"}, {\"image\": item[\"image_url\"]}],\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        response = requests.post(url, headers=headers, json=data, timeout=200)\n",
    "        response.raise_for_status()\n",
    "\n",
    "        embeddings = response.json().get(\"data\") or []\n",
    "        # Two inputs were sent (text, image); both are required for the average.\n",
    "        if len(embeddings) < 2:\n",
    "            print(f\"Incomplete embedding response for product {item['id']}\")\n",
    "            return None\n",
    "\n",
    "        return {\n",
    "            \"id\": item[\"id\"],\n",
    "            \"image_url\": item[\"image_url\"],\n",
    "            \"product_name\": item[\"product_name\"],\n",
    "            \"brand\": item[\"brand\"],\n",
    "            \"color\": item[\"color\"],\n",
    "            \"article_type\": item[\"article_type\"],\n",
    "            \"image_vector\": to_avg_vector(\n",
    "                [embeddings[0][\"embedding\"], embeddings[1][\"embedding\"]]\n",
    "            ),\n",
    "        }\n",
    "    except Exception as e:\n",
    "        # Best effort: a single failed product must not abort the whole batch.\n",
    "        print(f\"Error processing product {item['id']}: {e}\")\n",
    "        return None\n",
    "\n",
    "\n",
    "# encode image and product information in one vector\n",
    "def to_avg_vector(vectors):\n",
    "    vectors_array = np.array(vectors)\n",
    "\n",
    "    avg_vector = np.mean(vectors_array, axis=0)\n",
    "\n",
    "    norm = np.linalg.norm(avg_vector)\n",
    "    if norm > 0:\n",
    "        normalized_avg_vector = avg_vector / norm\n",
    "    else:\n",
    "        normalized_avg_vector = avg_vector\n",
    "\n",
    "    return normalized_avg_vector.tolist()\n",
    "\n",
    "\n",
    "print(\"Getting embeddings...\")\n",
    "\n",
    "with ThreadPoolExecutor(max_workers=10) as executor:\n",
    "    embedding_results = list(\n",
    "        tqdm(\n",
    "            executor.map(\n",
    "                get_single_image_embedding, demo_image_data, repeat(jina_api_key)\n",
    "            ),\n",
    "            total=len(demo_image_data),\n",
    "            desc=\"Getting embeddings\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Failed requests come back as None; drop them so later cells (which read\n",
    "# item[\"image_vector\"]) only see products that actually have a vector.\n",
    "products_with_vectors = [p for p in embedding_results if p is not None]\n",
    "failed_count = len(embedding_results) - len(products_with_vectors)\n",
    "\n",
    "print(f\"Retrieved {len(products_with_vectors)} embeddings ({failed_count} failed)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Original products: 1000\n",
      "After filtering similar items: 758\n",
      "Removed 242 similar items\n"
     ]
    }
   ],
   "source": [
    "def _cosine_similarity(X, Y):\n",
    "    \"\"\"Compute cosine similarity between two sets of vectors.\"\"\"\n",
    "    X = np.array(X)\n",
    "    Y = np.array(Y)\n",
    "\n",
    "    if X.ndim == 1:\n",
    "        X = X.reshape(1, -1)\n",
    "    if Y.ndim == 1:\n",
    "        Y = Y.reshape(1, -1)\n",
    "\n",
    "    # Normalize the vectors\n",
    "    X_norm = X / np.linalg.norm(X, axis=1, keepdims=True)\n",
    "    Y_norm = Y / np.linalg.norm(Y, axis=1, keepdims=True)\n",
    "\n",
    "    return np.dot(X_norm, Y_norm.T)\n",
    "\n",
    "\n",
    "def filter_out_similar_items(items, threshold=0.98):\n",
    "    \"\"\"Filter out items that have very high similarity to previously seen items\"\"\"\n",
    "    filtered_items = []\n",
    "\n",
    "    for i, item1 in enumerate(items):\n",
    "        is_similar_to_existing = False\n",
    "\n",
    "        for existing_item in filtered_items:\n",
    "            similarity = _cosine_similarity(\n",
    "                [item1[\"image_vector\"]], [existing_item[\"image_vector\"]]\n",
    "            )[0][0]\n",
    "\n",
    "            if similarity >= threshold:\n",
    "                is_similar_to_existing = True\n",
    "                break\n",
    "\n",
    "        if not is_similar_to_existing:\n",
    "            filtered_items.append(item1)\n",
    "\n",
    "    return filtered_items\n",
    "\n",
    "\n",
    "# Drop near-duplicates before indexing\n",
    "filtered_products = filter_out_similar_items(\n",
    "    products_with_vectors,\n",
    "    threshold=0.98,  # cosine similarity at/above which two items count as duplicates\n",
    ")\n",
    "\n",
    "print(f\"Original products: {len(products_with_vectors)}\")\n",
    "print(f\"After filtering similar items: {len(filtered_products)}\")\n",
    "print(f\"Removed {len(products_with_vectors) - len(filtered_products)} similar items\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Setup Elasticsearch Index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Deleted existing index 'fashion_images'\n",
      "Created index 'fashion_images'\n"
     ]
    }
   ],
   "source": [
    "# Initialize Elasticsearch client\n",
    "es = Elasticsearch(elastic_host, api_key=elastic_api_key)\n",
    "\n",
    "# Define index name\n",
    "index_name = \"fashion_images\"\n",
    "\n",
    "# Define index mapping. `dims` must match the \"dimensions\" requested from\n",
    "# the Jina embeddings API (1024).\n",
    "mapping = {\n",
    "    \"mappings\": {\n",
    "        \"properties\": {\n",
    "            \"id\": {\"type\": \"keyword\"},\n",
    "            \"image_url\": {\"type\": \"keyword\"},\n",
    "            \"product_name\": {\"type\": \"keyword\"},\n",
    "            \"brand\": {\"type\": \"keyword\"},\n",
    "            \"color\": {\"type\": \"keyword\"},\n",
    "            \"article_type\": {\"type\": \"keyword\"},\n",
    "            \"image_vector\": {\n",
    "                \"type\": \"dense_vector\",\n",
    "                \"dims\": 1024,\n",
    "                \"index\": True,\n",
    "                \"similarity\": \"cosine\",\n",
    "                \"index_options\": {\"type\": \"flat\"},\n",
    "            },\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "# NOTE: drops any existing index of the same name (fine for a demo, destructive otherwise).\n",
    "if es.indices.exists(index=index_name):\n",
    "    es.indices.delete(index=index_name)\n",
    "    print(f\"Deleted existing index '{index_name}'\")\n",
    "\n",
    "# Pass the mappings as a keyword argument; the catch-all `body=` parameter is\n",
    "# deprecated in the 8.x Python client.\n",
    "es.indices.create(index=index_name, mappings=mapping[\"mappings\"])\n",
    "print(f\"Created index '{index_name}'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Index Documents with Image Vectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "start\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Indexing images: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 758/758 [00:26<00:00, 28.72it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Successfully indexed 758 documents\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "def index_single_image(item):\n",
    "    \"\"\"Index one product document, using its product id as the document _id.\n",
    "\n",
    "    Returns 1 on success and 0 on failure (the error is printed), so the\n",
    "    results of a parallel map can simply be summed.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        es.index(index=index_name, id=item[\"id\"], document=item)\n",
    "    except Exception as error:\n",
    "        print(f\"Error indexing document {item['id']}: {error}\")\n",
    "        return 0\n",
    "    else:\n",
    "        return 1\n",
    "\n",
    "\n",
    "print(\"Indexing documents...\")\n",
    "\n",
    "# Index the documents in parallel (one index request per document)\n",
    "with ThreadPoolExecutor(max_workers=10) as executor:\n",
    "    results = list(\n",
    "        tqdm(\n",
    "            executor.map(index_single_image, filtered_products),\n",
    "            total=len(filtered_products),\n",
    "            desc=\"Indexing images\",\n",
    "        )\n",
    "    )\n",
    "\n",
    "# Make the new documents searchable right away; otherwise \"Run All\" can\n",
    "# query the index before Elasticsearch's periodic refresh has picked them up.\n",
    "es.indices.refresh(index=index_name)\n",
    "\n",
    "indexed_count = sum(results)\n",
    "print(f\"Successfully indexed {indexed_count} documents\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Query Images with Text Search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Creating text embedding for: 'pants'\n",
      "\n",
      "Searching for items similar to: 'pants'\n",
      "Found 150 similar images\n"
     ]
    }
   ],
   "source": [
    "SEARCH_QUERY = \"pants\"\n",
    "\n",
    "\n",
    "def get_text_embedding(text, jina_api_key):\n",
    "    \"\"\"Embed a search query with Jina (``retrieval.query`` task).\n",
    "\n",
    "    Args:\n",
    "        text: Query string to embed.\n",
    "        jina_api_key: Jina API key, sent as a bearer token.\n",
    "\n",
    "    Returns:\n",
    "        The first embedding in the response, or ``None`` if the request fails\n",
    "        or the response contains no embeddings.\n",
    "    \"\"\"\n",
    "    payload = {\n",
    "        \"model\": \"jina-embeddings-v4\",\n",
    "        \"dimensions\": 1024,\n",
    "        \"normalized\": True,\n",
    "        \"embedding_type\": \"float\",\n",
    "        \"task\": \"retrieval.query\",\n",
    "        \"input\": [{\"text\": text}],\n",
    "    }\n",
    "    request_headers = {\n",
    "        \"Content-Type\": \"application/json\",\n",
    "        \"Authorization\": f\"Bearer {jina_api_key}\",\n",
    "    }\n",
    "\n",
    "    try:\n",
    "        response = requests.post(\n",
    "            \"https://api.jina.ai/v1/embeddings\",\n",
    "            headers=request_headers,\n",
    "            json=payload,\n",
    "            timeout=30,\n",
    "        )\n",
    "        response.raise_for_status()\n",
    "        body = response.json()\n",
    "        if \"data\" in body and len(body[\"data\"]) > 0:\n",
    "            return body[\"data\"][0][\"embedding\"]\n",
    "    except Exception as e:\n",
    "        # Best effort: report the failure and fall through to None.\n",
    "        print(f\"Error getting text embedding: {e}\")\n",
    "\n",
    "    return None\n",
    "\n",
    "\n",
    "def search_similar_images(es, index_name, query_vector, k=20):\n",
    "    \"\"\"Search for similar images using vector similarity\"\"\"\n",
    "    query = {\n",
    "        \"knn\": {\n",
    "            \"field\": \"image_vector\",\n",
    "            \"query_vector\": query_vector,\n",
    "            \"k\": k,\n",
    "        },\n",
    "        \"size\": k,\n",
    "    }\n",
    "\n",
    "    response = es.search(index=index_name, body=query)\n",
    "\n",
    "    results = []\n",
    "    for hit in response[\"hits\"][\"hits\"]:\n",
    "        # Find the original product data to get additional info\n",
    "        product_id = hit[\"_source\"][\"id\"]\n",
    "\n",
    "        results.append(\n",
    "            {\n",
    "                \"id\": product_id,\n",
    "                \"image_url\": hit[\"_source\"][\"image_url\"],\n",
    "                \"image_vector\": hit[\"_source\"][\"image_vector\"],\n",
    "                \"score\": hit[\"_score\"],\n",
    "                \"product_name\": hit[\"_source\"][\"product_name\"],\n",
    "                \"brand\": hit[\"_source\"][\"brand\"],\n",
    "                \"color\": hit[\"_source\"][\"color\"],\n",
    "                \"article_type\": hit[\"_source\"][\"article_type\"],\n",
    "            }\n",
    "        )\n",
    "\n",
    "    return results\n",
    "\n",
    "\n",
    "print(f\"Creating text embedding for: '{SEARCH_QUERY}'\")\n",
    "query_vector = get_text_embedding(SEARCH_QUERY, jina_api_key)\n",
    "\n",
    "# Start empty so the display cells below still run if the embedding call fails.\n",
    "search_results = []\n",
    "if query_vector:\n",
    "    print(f\"\\nSearching for items similar to: '{SEARCH_QUERY}'\")\n",
    "    # Retrieve a much larger candidate pool (150) than is displayed (10),\n",
    "    # presumably so the re-ranking step has alternatives to choose from.\n",
    "    search_results = search_similar_images(es, index_name, query_vector, k=150)\n",
    "    print(f\"Found {len(search_results)} similar images\")\n",
    "else:\n",
    "    print(\"Failed to get text embedding\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Display Search Results\n",
    "\n",
    "Showing results for text search: **\"pants\"**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<h2>Original Search Results</h2><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\">\n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/33fa22c1b481c6ffc459f7374e45a8c4_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 9785</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Urban Yoga Women Summer Bottoms Navy Blue Track Pants\">Urban Yoga Women Summer B...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Navy Blue</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.862</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/2be067c2ea4c63bfe36dda593da764ad_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 7128</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Urban Yoga Women Bottom Black Track Pant\">Urban Yoga Women Bottom B...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Black</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.861</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/3d5b733d71b39c4c07f92ee30f042980_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 19242</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Puma Women Grey Capri Pants\">Puma Women Grey Capri Pan...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Capris - Grey</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.858</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/839d586ad24f425ef60cc49e71419ebb_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 3921</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Urban Yoga Men's Bottom Black Track Pant\">Urban Yoga Men's Bottom B...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Black</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.857</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/b80233197f4b5331aa5122f6eff3a95b_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 52529</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Pepe Jeans Men Grey 3/4 Length Pants\">Pepe Jeans Men Grey 3/4 L...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Shorts - Grey</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.856</p>\n",
       "       </div>\n",
       "       </div><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\">\n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/2946e88d72bd1e297f5a1c70451cea02_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 4826</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"ADIDAS Men's Woven Dark Navy White Track Pant\">ADIDAS Men's Woven Dark N...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Navy Blue</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.855</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/15b6662c5e90be8ddbdeb55b7add5271_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 44664</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Wills Lifestyle Women Charcoal Trousers\">Wills Lifestyle Women Cha...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Charcoal</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.854</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/1adabb7afa1895e26187de51dc97f50d_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 7133</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Urban Yoga Men Bottom Grey Yoga Pants\">Urban Yoga Men Bottom Gre...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Grey</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.854</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/French-Connection-Women-Beige-Trouser_e9389131527c43b1b310189bcf4b5552_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 43522</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"French Connection Women Navy Trouser\">French Connection Women N...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Navy Blue</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.854</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/0c103d57e4f45b1676e30f781fa77383_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 18869</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Puma Women Black Core Track Pants\">Puma Women Black Core Tra...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Black</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.854</p>\n",
       "       </div>\n",
       "       </div><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "def display_images(images, title=\"Images\", max_per_row=5, max_images=10):\n",
    "    \"\"\"Render search results as an HTML card grid in the notebook output.\n",
    "\n",
    "    Text coming from the dataset is HTML-escaped before interpolation, and the\n",
    "    full product name is shown as a hover tooltip (``title``) on the\n",
    "    truncated label.\n",
    "\n",
    "    Args:\n",
    "        images: Sequence of result dicts with ``id`` and ``image_url`` keys and\n",
    "            optional ``score``, ``product_name``, ``article_type``, ``color``.\n",
    "        title: Heading shown above the grid.\n",
    "        max_per_row: Number of cards per row.\n",
    "        max_images: Maximum number of results to show.\n",
    "    \"\"\"\n",
    "    from html import escape  # local import keeps this helper self-contained\n",
    "\n",
    "    shown = list(images[:max_images])\n",
    "    rows = [shown[start : start + max_per_row] for start in range(0, len(shown), max_per_row)]\n",
    "\n",
    "    parts = [f\"<h2>{escape(str(title))}</h2>\"]\n",
    "    for row in rows:\n",
    "        parts.append('<div style=\"display: flex; flex-wrap: wrap; gap: 10px;\">')\n",
    "        for img in row:\n",
    "            score = img.get(\"score\", \"N/A\")\n",
    "            score_str = f\"{score:.3f}\" if isinstance(score, (int, float)) else \"N/A\"\n",
    "\n",
    "            full_name = str(img.get(\"product_name\", \"N/A\"))\n",
    "            short_name = full_name if len(full_name) <= 25 else full_name[:25] + \"...\"\n",
    "            details = f\"{img.get('article_type', '')} - {img.get('color', '')}\"\n",
    "\n",
    "            image_url = escape(str(img[\"image_url\"]))\n",
    "            product_id = escape(str(img[\"id\"]))\n",
    "            name_attr = escape(full_name)\n",
    "\n",
    "            parts.append(\n",
    "                f\"\"\"\n",
    "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
    "           <img src=\"{image_url}\" alt=\"{name_attr}\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
    "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: {product_id}</p>\n",
    "           <p style=\"margin: 5px 0; font-size: 11px;\" title=\"{name_attr}\">{escape(short_name)}</p>\n",
    "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">{escape(details)}</p>\n",
    "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: {score_str}</p>\n",
    "       </div>\n",
    "       \"\"\"\n",
    "            )\n",
    "        parts.append(\"</div>\")\n",
    "\n",
    "    display(HTML(\"\".join(parts)))\n",
    "\n",
    "\n",
    "display_images(search_results, \"Original Search Results\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Reranking with Maximum Marginal Relevance (MMR)\n",
    "MMR is a diversity-promoting algorithm that balances:\n",
    "\n",
    "**Relevance**: How well items match the query  \n",
    "**Diversity**: How different items are from each other  \n",
    "The algorithm iteratively selects items that are relevant to the query but different from already selected items."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<h2>Reranked Results (MMR)</h2><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\">\n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/33fa22c1b481c6ffc459f7374e45a8c4_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 9785</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Urban Yoga Women Summer Bottoms Navy Blue Track Pants\">Urban Yoga Women Summer B...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Navy Blue</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.862</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/87adc16694f6d4ee2c28a1233c57d5c9_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 41163</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Allen Solly Woman Khaki Trousers\">Allen Solly Woman Khaki T...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Khaki</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.839</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/d33da6aaac593337751b91e04938740f_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 13255</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Palm Tree Kids Boys Check White Shorts\">Palm Tree Kids Boys Check...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Shorts - White</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.835</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/603270a4ee6b0e6e03e7195dba64bc44_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 4774</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"ADIDAS Women 3S Pink Track Pants\">ADIDAS Women 3S Pink Trac...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Track Pants - Pink</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.837</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/b80233197f4b5331aa5122f6eff3a95b_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 52529</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Pepe Jeans Men Grey 3/4 Length Pants\">Pepe Jeans Men Grey 3/4 L...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Shorts - Grey</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.856</p>\n",
       "       </div>\n",
       "       </div><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\">\n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/74c513a0a3ed8a2b22e1d8e2bce9bd83_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 22466</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Myntra Women Cream Patiala Salwar\">Myntra Women Cream Patial...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Leggings - Cream</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.836</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/1fc7af016c7f50bf24f9ebdf5144ceed_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 44906</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Puma Men White 3/4 Length Pants\">Puma Men White 3/4 Length...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Shorts - White</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.853</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/8c3407346f5b92621c279c7a8ab2fe0b_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 32406</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Arrow Woman Black Trousers\">Arrow Woman Black Trouser...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Black</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.853</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/United-Colors-of-Benetton-Green-Trouser_29005e896e9d76a457a7f1c280ca4448_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 57824</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"United Colors of Benetton Green Trouser\">United Colors of Benetton...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Green</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.842</p>\n",
       "       </div>\n",
       "       \n",
       "       <div style=\"text-align: center; margin-bottom: 20px;\">\n",
       "           <img src=\"http://assets.myntassets.com/h_480,q_95,w_360/v1/images/style/properties/764eeb797cd21798eb5b9e91cc9f9ae0_images.jpg\" style=\"width: 150px; height: 200px; object-fit: cover; border: 1px solid #ddd;\">\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; font-weight: bold;\">ID: 30919</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px;\" alit=\"Fabindia Women Pink Harem Pants\">Fabindia Women Pink Harem...</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 11px; color: #666;\">Trousers - Pink</p>\n",
       "           <p style=\"margin: 5px 0; font-size: 12px; color: #007bff;\">Score: 0.840</p>\n",
       "       </div>\n",
       "       </div><div style=\"display: flex; flex-wrap: wrap; gap: 10px;\"></div>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# taken from: https://github.com/elastic/elasticsearch-py/blob/main/elasticsearch/helpers/vectorstore/_utils.py#L39\n",
    "def maximal_marginal_relevance(\n",
    "    query_embedding: List[float],\n",
    "    embedding_list: List[List[float]],\n",
    "    lambda_mult: float = 0.5,\n",
    "    k: int = 4,\n",
    ") -> List[int]:\n",
    "    query_embedding_arr = np.array(query_embedding)\n",
    "\n",
    "    if min(k, len(embedding_list)) <= 0:\n",
    "        return []\n",
    "    if query_embedding_arr.ndim == 1:\n",
    "        query_embedding_arr = np.expand_dims(query_embedding_arr, axis=0)\n",
    "    similarity_to_query = _cosine_similarity(query_embedding_arr, embedding_list)[0]\n",
    "    most_similar = int(np.argmax(similarity_to_query))\n",
    "    idxs = [most_similar]\n",
    "    selected = np.array([embedding_list[most_similar]])\n",
    "    while len(idxs) < min(k, len(embedding_list)):\n",
    "        best_score = -np.inf\n",
    "        idx_to_add = -1\n",
    "        similarity_to_selected = _cosine_similarity(embedding_list, selected)\n",
    "        for i, query_score in enumerate(similarity_to_query):\n",
    "            if i in idxs:\n",
    "                continue\n",
    "            redundant_score = max(similarity_to_selected[i])\n",
    "            equation_score = (\n",
    "                lambda_mult * query_score - (1 - lambda_mult) * redundant_score\n",
    "            )\n",
    "            if equation_score > best_score:\n",
    "                best_score = equation_score\n",
    "                idx_to_add = i\n",
    "        idxs.append(idx_to_add)\n",
    "        selected = np.append(selected, [embedding_list[idx_to_add]], axis=0)\n",
    "    return idxs\n",
    "\n",
    "\n",
    "candidate_vectors = [hit[\"image_vector\"] for hit in search_results]\n",
    "mmr_indices = maximal_marginal_relevance(\n",
    "    query_embedding=query_vector,\n",
    "    embedding_list=candidate_vectors,\n",
    "    lambda_mult=0.5,  # 0.5 weighs relevance and diversity equally\n",
    "    k=100,  # upper bound: at most len(search_results) indices come back\n",
    ")\n",
    "\n",
    "# MMR returns positions into search_results, listed in their new (diversified) order.\n",
    "reranked_results = [search_results[pos] for pos in mmr_indices]\n",
    "display_images(reranked_results, \"Reranked Results (MMR)\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
